#include <hip/hip_runtime.h>
#include "aiter_hip_common.h"
#include <hip/hip_bf16.h>
#include "{{triton_header}}"


// Argument block that the launcher in this file fills in and passes (together
// with its size) to the assembly kernel via impl.launch_kernel().
//
// The struct is declared packed, and every payload field is immediately
// followed by an explicit padding member: p2 after each pointer, p3 after each
// 32-bit scalar.
// NOTE(review): p2 / p3 are declared outside this file (presumably in
// aiter_hip_common.h) and presumably pad every argument to a fixed-size slot
// expected by the kernel -- confirm before adding, removing or reordering
// fields.
struct __attribute__((packed)) KernelArgs
{
    void *ptr_R;              // launcher stores `logits` here
    p2 _p0;
    void *ptr_LSE;            // launcher stores `attn_lse` here
    p2 _p1;
    void *ptr_Q;              // launcher stores `Q` here
    p2 _p2;
    void *ptr_KV;             // launcher stores `KV` here
    p2 _p3;
    void *ptr_LTP;            // launcher stores `kv_indptr` here
    p2 _p4;
    void *ptr_LTD;            // launcher stores `kv_page_indices` here
    p2 _p5;
    void *ptr_LTL;            // launcher stores `kv_last_page_lens` here
    p2 _p6;
    float scalar;             // launcher stores `softmax_scale` here
    p3 _p12;
    unsigned int s_MQA;       // gqa_ratio * max_seqlen_q
    p3 _p13;
    unsigned int s_kv_split;  // template-supplied number of KV splits
    p3 _p14;
    unsigned int s_Q_Bs;      // q_stride_0 * sizeof(q dtype) * max_seqlen_q
    p3 _p15;
    unsigned int s_Bs;        // kv_stride_0 * sizeof(kv dtype)
    p3 _p16;
    unsigned int s_log2_plen; // log2 of the template-supplied page size
    p3 _p17;
    void *ptr_QTP;            // launcher stores `qo_indptr` here
    p2 _p18;
};

// Kernel binary embedded at template-render time ({{bin_size}} bytes); handed
// to the AiterAsmKernelFast constructor in the launcher below.
// NOTE(review): presumably an HSA code object, judging by the name -- confirm.
unsigned char hsaco[{{bin_size}}] = { {{bin_data}} };

// C-linkage prototype of the launcher so that the symbol is exported without
// C++ name mangling. Parameter names mirror the definition further down in
// this file.
extern "C" {
void {{func_name}}(void* Q,                 //   [num_seqs, num_heads, head_size]
                    void* KV,                //   [num_page, page_size, num_kv_heads, head_size]
                    void* qo_indptr,         //   [batch_size+1]
                    void* kv_indptr,         //   [batch_size+1]
                    void* kv_page_indices,   //   [num_page_used]
                    void* kv_last_page_lens, //   [batch_size]
                    int max_seqlen_q,
                    float softmax_scale,
                    // the following three buffers are outputs
                    void* logits, //[batch_size, num_kv_splits, num_heads, v_head_dim]
                    void* attn_lse,   //[batch_size, num_kv_splits, num_heads,  1]
                    void* output,
                    // sizes and strides (passed by value)
                    int num_seqs,
                    int num_heads,
                    int num_kv_heads,
                    int q_stride_0,
                    int kv_stride_0,
                    int attn_lse_stride_0,
                    int attn_lse_stride_1,
                    int attn_lse_stride_2,
                    int output_stride_0,
                    int output_stride_1,
                    void* stream
                    );
}


// Host-side launcher.
//
// 1. Fills a KernelArgs block and launches the embedded assembly kernel
//    ("{{kernel_name}}") on `stream` with a grid of
//    (ceil(gqa_ratio * max_seqlen_q / sub_Q), num_seqs, {{num_kv_splits}})
//    and 256 threads per block.
// 2. Unless num_heads == 128 and the template-supplied split count is 1,
//    additionally invokes {{triton_kernel}} on the same stream with `logits`,
//    `attn_lse` and `output`.
//    NOTE(review): presumably this second stage merges the per-split partial
//    results into `output` -- confirm against the Triton kernel.
//
// Parameters:
//   Q, KV, qo_indptr, kv_indptr, kv_page_indices, kv_last_page_lens
//       opaque buffers forwarded to the kernel(s); see the shape notes on
//       each parameter.
//   max_seqlen_q   multiplied into the grid x extent and into s_MQA / s_Q_Bs.
//   softmax_scale  forwarded unchanged as KernelArgs::scalar.
//   logits, attn_lse, output
//       output buffers.
//   num_seqs       grid y extent.
//   num_heads, num_kv_heads
//       gqa_ratio = num_heads / num_kv_heads (integer division; num_kv_heads
//       must be non-zero).
//   q_stride_0, kv_stride_0
//       strides in elements; converted to bytes with sizeof of the
//       template-supplied dtypes.
//   attn_lse_stride_*, output_stride_*
//       forwarded to {{triton_kernel}} only.
//   stream         reinterpreted as hipStream_t.
void {{func_name}}(void* Q,                 //   [num_seqs, num_heads, head_size]
                    void* KV,                //   [num_page, page_size, num_kv_heads, head_size]
                    void* qo_indptr,         //   [batch_size+1]
                    void* kv_indptr,         //   [batch_size+1]
                    void* kv_page_indices,   //   [num_page_used]
                    void* kv_last_page_lens, //   [batch_size]
                    int max_seqlen_q,
                    float softmax_scale,
                    // following are output
                    void* logits, //[batch_size, num_kv_splits, num_heads, v_head_dim]
                    void* attn_lse,   //[batch_size, num_kv_splits, num_heads,  1]
                    void* output,
                    int num_seqs,
                    int num_heads,
                    int num_kv_heads,
                    int q_stride_0,
                    int kv_stride_0,
                    int attn_lse_stride_0,
                    int attn_lse_stride_1,
                    int attn_lse_stride_2,
                    int output_stride_0,
                    int output_stride_1,
                    void* stream
                    ){
    // Function-local static: constructed on the first call, reused afterwards.
    static AiterAsmKernelFast impl("{{kernel_name}}", hsaco);
    constexpr int page_size = {{page_size}};

    // floor(log2(page_size)) in pure integer arithmetic; avoids a float
    // round-trip for what is a compile-time integer constant.
    uint32_t log2_page = 0;
    for(int remaining = page_size; remaining > 1; remaining >>= 1)
    {
        ++log2_page;
    }

    const int gqa_ratio = num_heads / num_kv_heads;
    const int q_rows    = gqa_ratio * max_seqlen_q;
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);

    // Value-initialise so the explicit padding members are zero rather than
    // indeterminate stack bytes when the block is handed to the kernel.
    KernelArgs args{};
    size_t arg_size = sizeof(args);
    args.ptr_R = logits;
    args.ptr_LSE = attn_lse;
    args.ptr_Q = Q;
    args.ptr_KV = KV;
    args.ptr_LTP = kv_indptr;
    args.ptr_LTD = kv_page_indices;
    args.ptr_LTL = kv_last_page_lens;
    args.ptr_QTP = qo_indptr;
    args.scalar = softmax_scale;
    args.s_MQA = q_rows;
    args.s_kv_split = {{num_kv_splits}};
    // Strides arrive in elements; the kernel arguments are in bytes.
    args.s_Q_Bs = q_stride_0 * sizeof({{q_dtype}}) * max_seqlen_q;
    args.s_Bs = kv_stride_0 * sizeof({{kv_dtype}});
    args.s_log2_plen = log2_page;

    // Divisor used to spread the q_rows across grid x.
    const int sub_Q = gqa_ratio == 128 ? 128 : 16;
    impl.launch_kernel({&args,
                        &arg_size,
                        (q_rows + sub_Q - 1) / sub_Q, // gdx
                        num_seqs,              // gdy
                        {{num_kv_splits}},              // gdz
                        256,                   // bdx: 4 wv64
                        1,                     // bdy
                        1,                     // bdz
                        hip_stream});

    // With 128 heads and a single KV split the second stage is skipped.
    if(num_heads == 128 && {{num_kv_splits}} == 1){
        return;
    }

    // The attn_lse strides are deliberately forwarded in the order 0, 2, 1.
    // NOTE(review): presumably this matches the Triton kernel's parameter
    // order -- confirm against its signature before changing.
    {{triton_kernel}}(hip_stream, logits, attn_lse, output, qo_indptr, kv_indptr, attn_lse_stride_0, attn_lse_stride_2, attn_lse_stride_1, output_stride_0, output_stride_1, num_seqs, num_heads, max_seqlen_q);
}